{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 4: Quantization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.utils import to_categorical\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "seed = 0\n",
    "np.random.seed(seed)\n",
    "import tensorflow as tf\n",
    "tf.random.set_seed(seed)\n",
    "import os\n",
    "os.environ['PATH'] = '/opt/Xilinx/Vivado/2019.2/bin:' + os.environ['PATH']"
   ]
  },
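  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the jet tagging dataset\n",
    "The preprocessed arrays are loaded from `.npy` files in the working directory rather than fetched from OpenML here."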
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_val = np.load('X_train_val.npy')\n",
    "X_test = np.load('X_test.npy')\n",
    "y_train_val = np.load('y_train_val.npy')\n",
    "y_test = np.load('y_test.npy')\n",
    "classes = np.load('classes.npy', allow_pickle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct a model\n",
    "This time we're going to use QKeras layers.\n",
    "QKeras is \"Quantized Keras\", a library for deep heterogeneous quantization of ML models.\n",
    "\n",
    "https://github.com/google/qkeras\n",
    "\n",
    "It is maintained by Google, and we recently added support for QKeras models to hls4ml."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.models import Sequential\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.regularizers import l1\n",
    "from callbacks import all_callbacks\n",
    "from tensorflow.keras.layers import Activation\n",
    "from qkeras.qlayers import QDense, QActivation\n",
    "from qkeras.quantizers import quantized_bits, quantized_relu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We're using the `QDense` layer instead of `Dense`, and `QActivation` instead of `Activation`. We're also specifying `kernel_quantizer=quantized_bits(6,0,alpha=1)`. This uses 6 bits (of which 0 are integer bits) for the weights. We use the same quantization for the biases, and `quantized_relu(6)` for 6-bit ReLU activations."
   ]
  },
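  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the bit widths concrete, the next cell gives a minimal NumPy sketch of rounding values onto a signed fixed-point grid and saturating at the ends of the range. It is an illustration only, not the QKeras implementation. The helper `fixed_point_quantize` and the array names are hypothetical, and because tools differ in whether they count the sign bit, the sketch takes the number of fractional bits explicitly; the value 5 is an assumption chosen for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fixed_point_quantize(x, frac_bits=5, int_bits=0):\n",
    "    # Hypothetical helper: round to the nearest grid point, then saturate.\n",
    "    step = 2.0 ** -frac_bits      # spacing between representable values\n",
    "    lo = -(2.0 ** int_bits)       # most negative representable value\n",
    "    hi = 2.0 ** int_bits - step   # most positive representable value\n",
    "    rounded = np.round(np.asarray(x, dtype=float) / step) * step\n",
    "    return np.clip(rounded, lo, hi)\n",
    "\n",
    "example_weights = np.array([-1.5, -0.3, 0.01, 0.5, 2.0])\n",
    "example_quantized = fixed_point_quantize(example_weights)\n",
    "print(example_quantized)"
   ]
  },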
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(QDense(64, input_shape=(16,), name='fc1',\n",
    "                 kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),\n",
    "                 kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n",
    "model.add(QActivation(activation=quantized_relu(6), name='relu1'))\n",
    "model.add(QDense(32, name='fc2',\n",
    "                 kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),\n",
    "                 kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n",
    "model.add(QActivation(activation=quantized_relu(6), name='relu2'))\n",
    "model.add(QDense(32, name='fc3',\n",
    "                 kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),\n",
    "                 kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n",
    "model.add(QActivation(activation=quantized_relu(6), name='relu3'))\n",
    "model.add(QDense(5, name='output',\n",
    "                 kernel_quantizer=quantized_bits(6,0,alpha=1), bias_quantizer=quantized_bits(6,0,alpha=1),\n",
    "                 kernel_initializer='lecun_uniform', kernel_regularizer=l1(0.0001)))\n",
    "model.add(Activation(activation='softmax', name='softmax'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train sparse\n",
    "Let's train with model sparsity again, since QKeras layers are prunable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow_model_optimization.python.core.sparsity.keras import prune, pruning_callbacks, pruning_schedule\n",
    "from tensorflow_model_optimization.sparsity.keras import strip_pruning\n",
    "pruning_params = {\"pruning_schedule\" : pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)}\n",
    "model = prune.prune_low_magnitude(model, **pruning_params)"
   ]
  },
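  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough picture of what 75% sparsity means, the next cell zeroes the smallest-magnitude 75% of a random weight matrix with NumPy. This is a one-shot sketch under stated assumptions, not what `prune_low_magnitude` does internally, since that wrapper applies its schedule step by step during training. The names `magnitude_prune`, `example_kernel` and `example_pruned` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def magnitude_prune(weights, sparsity=0.75):\n",
    "    # Hypothetical helper: zero the smallest-magnitude fraction of the weights.\n",
    "    n_prune = int(round(sparsity * weights.size))\n",
    "    if n_prune == 0:\n",
    "        return weights.copy()\n",
    "    # Magnitude of the largest weight that still gets pruned\n",
    "    threshold = np.sort(np.abs(weights), axis=None)[n_prune - 1]\n",
    "    return np.where(np.abs(weights) > threshold, weights, 0.0)\n",
    "\n",
    "# A local generator keeps the notebook's global NumPy seed untouched\n",
    "example_kernel = np.random.default_rng(0).normal(size=(16, 64))\n",
    "example_pruned = magnitude_prune(example_kernel)\n",
    "print('fraction of zeros:', np.mean(example_pruned == 0))"
   ]
  },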
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model\n",
    "We'll use the same settings as for the model in part 1: Adam optimizer with categorical crossentropy loss.\n",
    "The callbacks will decay the learning rate and save the model into the directory `model_3`.\n",
    "The model isn't very complex, so this should take just a few minutes even on a CPU.\n",
    "If you've restarted the notebook kernel after training once, set `train = False` to load the trained model rather than training again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train = True\n",
    "if train:\n",
    "    adam = Adam(learning_rate=0.0001)\n",
    "    model.compile(optimizer=adam, loss=['categorical_crossentropy'], metrics=['accuracy'])\n",
    "    callbacks = all_callbacks(stop_patience = 1000,\n",
    "                              lr_factor = 0.5,\n",
    "                              lr_patience = 10,\n",
    "                              lr_epsilon = 0.000001,\n",
    "                              lr_cooldown = 2,\n",
    "                              lr_minimum = 0.0000001,\n",
    "                              outputDir = 'model_3')\n",
    "    callbacks.callbacks.append(pruning_callbacks.UpdatePruningStep())\n",
    "    model.fit(X_train_val, y_train_val, batch_size=1024,\n",
    "              epochs=30, validation_split=0.25, shuffle=True,\n",
    "              callbacks = callbacks.callbacks)\n",
    "    # Save the model again but with the pruning 'stripped' to use the regular layer types\n",
    "    model = strip_pruning(model)\n",
    "    model.save('model_3/KERAS_check_best_model.h5')\n",
    "else:\n",
    "    from tensorflow.keras.models import load_model\n",
    "    from qkeras.utils import _add_supported_quantized_objects\n",
    "    co = {}\n",
    "    _add_supported_quantized_objects(co)\n",
    "    model = load_model('model_3/KERAS_check_best_model.h5', custom_objects=co)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check performance\n",
    "How does this model, trained with 6-bit quantization and 75% sparsity, compare against the original model? Let's report the accuracy and make a ROC curve. The unpruned baseline model from part 1 is shown with solid lines, and the quantized, pruned model with dashed lines.\n",
    "\n",
    "\n",
    "We should also check that hls4ml can respect the choice to use 6 bits throughout the model and match the accuracy. We'll generate a configuration from this quantized model, and plot its performance as the dotted lines.\n",
    "The generated configuration is printed out. You'll notice that it uses 7 bits for the type, even though we specified 6. That's because QKeras doesn't count the sign bit when we specify the number of bits, so the type that actually gets used needs 1 more.\n",
    "\n",
    "We also use the `OutputRoundingSaturationMode` optimizer pass of `hls4ml` to make the Activation layers round and saturate, rather than truncate, when casting. This is important for getting good model accuracy when using small bit precision activations. And we'll set a different data type for the tables used in the Softmax, just for a bit of extra performance.\n",
    "\n",
    "\n",
    "**Make sure you've trained the model from part 1**"
   ]
  },
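  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why rounding matters at small bit widths, the next cell compares truncation with rounding on a coarse grid using NumPy. It is a sketch, not the hls4ml implementation: it assumes that truncation rounds toward minus infinity and that rounding goes to the nearest step with ties upward. The 6 fractional bits and the names `cast_truncate`, `cast_round` and `x_demo` are chosen only for illustration. Truncation shifts every value down by about half a step on average, while rounding leaves almost no average shift."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "frac_bits_demo = 6                    # assumed number of fractional bits\n",
    "step_demo = 2.0 ** -frac_bits_demo    # spacing of the fixed-point grid\n",
    "\n",
    "def cast_truncate(x):\n",
    "    # Drop the extra bits: always lands on the grid point below x\n",
    "    return np.floor(x / step_demo) * step_demo\n",
    "\n",
    "def cast_round(x):\n",
    "    # Nearest grid point, with ties going upward\n",
    "    return np.floor(x / step_demo + 0.5) * step_demo\n",
    "\n",
    "x_demo = np.random.default_rng(0).uniform(0.0, 1.0, size=10000)\n",
    "err_truncate = np.mean(cast_truncate(x_demo) - x_demo)\n",
    "err_round = np.mean(cast_round(x_demo) - x_demo)\n",
    "print('mean error, truncate:', err_truncate)\n",
    "print('mean error, round:   ', err_round)"
   ]
  },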
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hls4ml\n",
    "import plotting\n",
    "hls4ml.model.optimizer.OutputRoundingSaturationMode.layers = ['Activation']\n",
    "hls4ml.model.optimizer.OutputRoundingSaturationMode.rounding_mode = 'AP_RND'\n",
    "hls4ml.model.optimizer.OutputRoundingSaturationMode.saturation_mode = 'AP_SAT'\n",
    "\n",
    "config = hls4ml.utils.config_from_keras_model(model, granularity='name')\n",
    "config['LayerName']['softmax']['exp_table_t'] = 'ap_fixed<18,8>'\n",
    "config['LayerName']['softmax']['inv_table_t'] = 'ap_fixed<18,4>'\n",
    "print(\"-----------------------------------\")\n",
    "plotting.print_dict(config)\n",
    "print(\"-----------------------------------\")\n",
    "hls_model = hls4ml.converters.convert_from_keras_model(model,\n",
    "                                                       hls_config=config,\n",
    "                                                       output_dir='model_3/hls4ml_prj',\n",
    "                                                       fpga_part='xcu250-figd2104-2L-e')\n",
    "hls_model.compile()\n",
    "\n",
    "y_qkeras = model.predict(X_test)\n",
    "y_hls = hls_model.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.models import load_model\n",
    "\n",
    "model_ref = load_model('model_1/KERAS_check_best_model.h5')\n",
    "y_ref = model_ref.predict(X_test)\n",
    "\n",
    "print(\"Accuracy baseline:  {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_ref, axis=1))))\n",
    "print(\"Accuracy pruned, quantized: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_qkeras, axis=1))))\n",
    "print(\"Accuracy hls4ml: {}\".format(accuracy_score(np.argmax(y_test, axis=1), np.argmax(y_hls, axis=1))))\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 9))\n",
    "_ = plotting.makeRoc(y_test, y_ref, classes)\n",
    "plt.gca().set_prop_cycle(None) # reset the colors\n",
    "_ = plotting.makeRoc(y_test, y_qkeras, classes, linestyle='--')\n",
    "plt.gca().set_prop_cycle(None) # reset the colors\n",
    "_ = plotting.makeRoc(y_test, y_hls, classes, linestyle=':')\n",
    "\n",
    "from matplotlib.lines import Line2D\n",
    "lines = [Line2D([0], [0], ls='-'),\n",
    "         Line2D([0], [0], ls='--'),\n",
    "         Line2D([0], [0], ls=':')]\n",
    "from matplotlib.legend import Legend\n",
    "leg = Legend(ax, lines, labels=['baseline', 'pruned, quantized', 'hls4ml'],\n",
    "            loc='lower right', frameon=False)\n",
    "ax.add_artist(leg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Synthesize\n",
    "Now let's synthesize this quantized, pruned model.\n",
    "\n",
    "**The synthesis will take a while**\n",
    "\n",
    "While the C synthesis is running, we can monitor its progress by opening a terminal from the notebook home and following the log file:\n",
    "\n",
    "`tail -f model_3/hls4ml_prj/vivado_hls.log`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls_model.build(csim=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check the reports\n",
    "Print out the reports generated by Vivado HLS. Pay attention to the 'Utilization Estimates' section in particular this time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls4ml.report.read_vivado_report('model_3/hls4ml_prj')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print the report for the model trained in part 1. Compared to that model, the new model has been trained with low-precision quantization and 75% pruning. You should be able to see that we have saved a lot of resources compared to where we started in part 1. At the same time, referring to the ROC curve above, the model performance is pretty much identical even with this drastic compression!\n",
    "\n",
    "**Note you need to have trained and synthesized the model from part 1**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls4ml.report.read_vivado_report('model_1/hls4ml_prj')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Print the report for the model trained in part 3. Both these models were trained with 75% sparsity, but the new model uses 6-bit precision as well. You can see how Vivado HLS has moved multiplication operations from DSPs into LUTs, reducing the \"critical\" resource usage.\n",
    "\n",
    "**Note you need to have trained and synthesized the model from part 3**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hls4ml.report.read_vivado_report('model_2/hls4ml_prj')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NB\n",
    "Note as well that the Vivado HLS resource estimates tend to _overestimate_ LUTs, while generally estimating the DSPs correctly. Running the subsequent stages of FPGA compilation reveals more realistic resource usage. You can run the next step, 'logic synthesis', with `hls_model.build(synth=True, vsynth=True)`, but we skipped it in this tutorial in the interest of time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
